"""
Harmonic analysis with shorter time windows.
This Python script runs on NASA HECC, where 254 jobs are submitted: one for each altimeter track.

Author: Badarvada Yadidya
Created on: August 2023

"""

import warnings
warnings.filterwarnings('ignore')
import numpy as np
import xarray as xr
import pandas as pd
import sys
import os
import os.path as op
import datetime
import utide as ut
import zarr
import dask

# Model fields: `ds` supplies fssh600, `ds_old` supplies tssh and fssh.
ds = xr.open_zarr('/nobackup/ybadarva/diurnal/altim_nt_d.zarr')
ds_old = xr.open_zarr('/nobackup/ybadarva/diurnal/altim_nt.zarr')

# Track geometry. Latitude gaps are interpolated along 'np' then 'nt', and any
# remaining edge gaps are forward/back filled along both dimensions.
ds2 = xr.open_dataset('/home1/ybadarva/nlyss_info.nc')
lat = (
    ds2.lat.T
    .interpolate_na(dim='np')
    .interpolate_na(dim='nt')
    .ffill('nt')
    .bfill('nt')
    .ffill('np')
    .bfill('np')
)
# Anything not strictly outside [-1, 1] (including leftover NaN) becomes 1.
# NOTE(review): presumably keeps utide's latitude argument away from the equator; confirm.
lat = lat.where((lat < -1) | (lat > 1), 1)

# Longitude folded into 0..360: values above 360 are shifted down, negatives up.
lon = ds2.lon.T
lon = lon.where(lon <= 360, lon - 360)
lon = lon.where(lon >= 0, lon + 360)

# Altimeter along-track SLA, scaled by 100 (presumably m -> cm; confirm) and
# ordered (nc, np, nt); `time` holds the matching MATLAB datenums.
altim = xr.open_dataset('/nobackup/ybadarva/arinne/finalized/Step03/outputs/nlyss_altim.nc')
sla = (altim.sla * 100).transpose('nc', 'np', 'nt')
time = altim.time.T

def detrend_dim(da, dim, deg=1):
    """Return *da* with a degree-``deg`` polynomial trend along *dim* removed.

    The trend is fitted with ``DataArray.polyfit`` against the coordinate
    values of *dim* (not integer positions). It is evaluated with
    ``xr.polyval`` on that same coordinate and subtracted from *da*.
    """
    coefficients = da.polyfit(dim=dim, deg=deg).polyfit_coefficients
    trend = xr.polyval(da[dim], coefficients)
    return da - trend

def datetime_conv(matlab_dn):
    """Convert a MATLAB datenum to a ``datetime.datetime``.

    The integer part is taken as a proleptic ordinal shifted by 366 days
    (MATLAB's day count runs 366 days ahead of Python's). The fractional part
    is added as the time of day. NaN input yields ``pd.NaT``.
    """
    if not np.isnan(matlab_dn):
        whole_days = datetime.datetime.fromordinal(int(matlab_dn))
        time_of_day = datetime.timedelta(days=matlab_dn % 1)
        return whole_days + time_of_day - datetime.timedelta(days=366)
    return pd.NaT

# Sliding windows are shifted so they never extend outside this period.
min_start_time = pd.Timestamp('2017-01-01 10:00')
max_end_time = pd.Timestamp('2019-12-31 23:00')

# Constituents fitted by utide, and the subsets used for each reconstruction.
constit = ['M2', 'S2', 'N2', 'K1', 'O1']
constit_sets = [
    ['M2', 'S2'],
    ['M2', 'S2', 'N2'],
    ['K1', 'O1'],
    ['M2', 'S2', 'K1', 'O1'],
    ['M2', 'S2', 'N2', 'K1', 'O1'],
]

# Track ('nt') index handled by this job, given as the first CLI argument.
n = int(sys.argv[1])
tssh = ds_old.tssh

# NOTE(review): an earlier comment said 'sssh' is fssh600 and 'fssh' is fssh900.
# However, `fssh` below subtracts ds.fssh600, and the results are saved as
# 'fssh300' (from sssh) and 'fssh600' (from fssh). Confirm which filtered field
# ds_old.fssh is before relying on these names.
# The 100/9.8 factor presumably converts to cm of sea level; confirm units.
sssh = ((tssh[:, :, n] - ds_old.fssh[:, :, n]) * (100 / 9.8)).compute()
fssh = ((tssh[:, :, n] - ds.fssh600[:, :, n]) * (100 / 9.8)).compute()

# One label per entry of constit_sets, e.g. ['M2', 'S2'] -> 'M2S2'.
set_const = [''.join(names) for names in constit_sets]

path = f'/nobackup/ybadarva/diurnal/final2/final_sliding_{n}.nc'

time_periods = [15, 30, 60, 90, 120, 150, 180, 210]  # analysis window lengths, days


def _time_window(center, days):
    """Return an hourly window spanning ``days`` days (days*24 + 1 samples).

    The window is centred on *center*. If it would start before
    ``min_start_time`` (or end after ``max_end_time``), it is slid, not
    shortened, so that it starts (or ends) exactly on that bound.
    """
    hours = days * 24
    half = pd.Timedelta(hours / 2, 'h')  # 'H' alias is deprecated since pandas 2.2
    window = pd.date_range(start=center - half, end=center + half, freq='h')
    if window[0] < min_start_time:
        window = pd.date_range(start=min_start_time, periods=hours + 1, freq='h')
    if window[-1] > max_end_time:
        window = pd.date_range(end=max_end_time, periods=hours + 1, freq='h')
    return window


def _predict(coef, when, names):
    """Evaluate the utide fit *coef* at the single time *when*, using only *names*."""
    rec = ut.reconstruct(when, coef, constit=names, verbose=False)
    # reconstruct returns an array even for one time. Extract the value
    # explicitly: assigning a size-1 array into a scalar slot is deprecated in NumPy.
    return np.asarray(rec['h']).ravel()[0]


def _append_profile(ds_out, path, first):
    """Write *ds_out* to *path*, concatenating along 'np' onto earlier profiles.

    The existing file is fully loaded and closed first. The combined dataset
    is then written to a temporary file that replaces *path* with os.replace.
    A crash mid-write therefore never leaves the results file deleted, and no
    shell command is involved.
    """
    if first:
        ds_out.to_netcdf(path)
        return
    with xr.open_dataset(path) as ds_existing:
        combined = xr.concat([ds_existing, ds_out], dim='np').load()
    tmp_path = f'{path}.tmp'
    combined.to_netcdf(tmp_path)
    os.replace(tmp_path, path)


n_windows = len(time_periods)
n_sets = len(constit_sets)
n_cycles = sla.sizes['nc']  # previously hard-coded as 136

# For every along-track point m of track n: fit utide to the two filtered SSH
# series over each window length, and predict them at each altimeter pass.
# The file is only created when m == 0. To resume a partial run, start the
# range at the next unwritten index.
for m in range(ds.sizes['np']):  # `ds.dims[...]` mapping access is deprecated in xarray
    print(m)
    lat_mn = lat[m, n].values.item()
    ste_sh = np.full((n_windows, n_sets, n_cycles, 1, 1), np.nan)
    fil_sh = np.full((n_windows, n_sets, n_cycles, 1, 1), np.nan)

    # Pass times for this point: NaT where the time or the SLA is missing,
    # and NaT for passes outside 2017-2019.
    time_mn = time[:, m, n].values
    sla_mn = sla[:, m, n]
    t1 = xr.DataArray([datetime_conv(dn) for dn in time_mn], dims='nc')
    missing = xr.DataArray(np.isnan(time_mn) | np.isnan(sla_mn.values), dims='nc')
    t1 = t1.where(~missing)  # default fill for datetime64 is NaT
    out_of_range = (t1.dt.year < 2017) | (t1.dt.year > 2019)
    t1 = t1.where(~out_of_range, np.datetime64('NaT'))

    # The same observed SLA row is repeated for every window length.
    sla_new = np.full((n_windows, n_cycles, 1, 1), np.nan)
    sla_new[:, :, 0, 0] = sla_mn.where(~out_of_range, np.nan).values

    for i, t in enumerate(t1):
        if pd.isnull(t):
            continue
        nearest_time = ds['time'].sel(time=t, method='nearest').values
        for idx, days in enumerate(time_periods):
            window = _time_window(nearest_time, days)
            sssh_days = detrend_dim(sssh[:, m].sel(time=window), dim='time').compute()
            fssh_days = detrend_dim(fssh[:, m].sel(time=window), dim='time').compute()
            # With no valid samples there is nothing for ut.solve to fit. The
            # predictions stay NaN from initialisation and the SLA is blanked
            # too. Both series are checked; previously only sssh was.
            if np.all(np.isnan(sssh_days)) or np.all(np.isnan(fssh_days)):
                sla_new[idx, i, 0, 0] = np.nan
                continue
            coef_s = ut.solve(sssh_days.time, sssh_days, lat=lat_mn, constit=constit, verbose=False)
            coef_f = ut.solve(fssh_days.time, fssh_days, lat=lat_mn, constit=constit, verbose=False)
            for j, constit_set in enumerate(constit_sets):
                ste_sh[idx, j, i, 0, 0] = _predict(coef_s, nearest_time, constit_set)
                fil_sh[idx, j, i, 0, 0] = _predict(coef_f, nearest_time, constit_set)

    pred_dims = ['days', 'nw', *sla.dims]
    pred_coords = dict(days=np.array(time_periods), nw=np.array(set_const),
                       nc=sla.coords['nc'], np=np.array([m]), nt=np.array([n]))
    ste_sh = xr.DataArray(ste_sh, dims=pred_dims, coords=pred_coords, name='fssh300')
    fil_sh = xr.DataArray(fil_sh, dims=pred_dims, coords=pred_coords, name='fssh600')
    sla_new = xr.DataArray(sla_new, dims=['days', *sla.dims],
                           coords=dict(days=np.array(time_periods), nc=sla.coords['nc'],
                                       np=np.array([m]), nt=np.array([n])),
                           name='sla')

    ds_out = xr.merge([ste_sh, fil_sh, sla_new])
    _append_profile(ds_out, path, first=(m == 0))
